import torch
import torch_npu
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
                                                                                                                 
# Generate a short CogVideoX-5b video from a text prompt on an Ascend NPU,
# writing it to output.mp4; prints a notice and does nothing else when no
# NPU is available.
if torch.npu.is_available():
    print("NPU is available.")
    # Index of the NPU to use. `device` is derived from it so the two can
    # never disagree when the index is changed.
    device_id = 1
    torch.npu.set_device(device_id)
    device = torch.device(f"npu:{device_id}")

    # BF16 (torch.bfloat16) is the recommended dtype for this model; FP16 is
    # used here to lower precision. NOTE(review): confirm FP16 output quality
    # on CogVideoX-5b, or switch to torch.bfloat16 if the NPU supports it.
    torch_dtype = torch.float16
    print(f"Using data type: {torch_dtype}")

    # Seeded RNG on the target device so generation is reproducible.
    generator = torch.Generator(device=device).manual_seed(42)
    print("Generator created")
    # Adjacent string literals inside parentheses are concatenated at compile
    # time. A backslash continuation *inside* a string literal would instead
    # splice the next line's indentation into the prompt text.
    prompt = (
        "A panda sits on a wooden stool in a serene bamboo forest. "
        "The panda's fluffy paws strum a miniature acoustic guitar, "
        "producing soft, melodic tunes. "
    )
    print("Prompt set")

    pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch_dtype)
    print("Pipeline loaded")

    # Move the model weights onto the selected NPU (the generator above
    # already lives there).
    pipe.to(device)
    print("Model moved to device")
    print(f"Current NPU device: {torch.npu.current_device()}")

    # NOTE(review): 24 frames at fps=8 is ~3 s of video, assuming the pipeline
    # returns exactly num_frames frames; CogVideoX is commonly run with 49
    # frames — confirm 24 is supported by the installed diffusers version.
    video = pipe(
        prompt=prompt,
        num_videos_per_prompt=1,
        num_inference_steps=50,
        num_frames=24,
        guidance_scale=6,
        generator=generator,
    ).frames[0]
    print("Video generated")

    export_to_video(video, "output.mp4", fps=8)
    print("Video exported")

else:
    print("NPU is not available.")